{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp problem_types.seq_tag\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test setup\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from m3tl.test_base import TestBase\n",
    "from m3tl.input_fn import train_eval_input_fn\n",
    "from m3tl.test_base import test_top_layer\n",
    "test_base = TestBase()\n",
    "params = test_base.params\n",
    "\n",
    "hidden_dim = params.bert_config.hidden_size\n",
    "\n",
    "train_dataset = train_eval_input_fn(params=params)\n",
    "one_batch = next(train_dataset.as_numpy_iterator())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sequence Labeling(seq_tag)\n",
    "\n",
    "This module includes neccessary part to register sequence labeling problem type."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "from typing import List\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from m3tl.base_params import BaseParams\n",
    "from m3tl.problem_types.utils import (create_dummy_if_empty,\n",
    "                                      empty_tensor_handling_loss,\n",
    "                                      nan_loss_handling)\n",
    "from m3tl.special_tokens import PREDICT, TRAIN\n",
    "from m3tl.utils import (LabelEncoder, get_label_encoder_save_path, get_phase,\n",
    "                        need_make_label_encoder, variable_summaries)\n",
    "from tensorflow_addons.layers.crf import CRF\n",
    "from tensorflow_addons.text.crf import crf_log_likelihood\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Top Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "class SequenceLabel(tf.keras.Model):\n",
    "    def __init__(self, params: BaseParams, problem_name: str):\n",
    "        super(SequenceLabel, self).__init__(name=problem_name)\n",
    "        self.params = params\n",
    "        self.problem_name = problem_name\n",
    "        num_classes = self.params.get_problem_info(problem=problem_name, info_name='num_classes')\n",
    "        self.dense = tf.keras.layers.Dense(num_classes, activation=None)\n",
    "\n",
    "        self.dropout = tf.keras.layers.Dropout(1-params.dropout_keep_prob)\n",
    "\n",
    "        if self.params.crf:\n",
    "            self.crf = CRF(num_classes)\n",
    "            self.metric_fn = tf.keras.metrics.Accuracy(\n",
    "                name='{}_acc'.format(self.problem_name)\n",
    "            )\n",
    "        else:\n",
    "            self.metric_fn = tf.keras.metrics.SparseCategoricalAccuracy(\n",
    "                name='{}_acc'.format(self.problem_name))\n",
    "\n",
    "    def return_crf_result(self, labels: tf.Tensor, logits: tf.Tensor, mode: str, input_mask: tf.Tensor):\n",
    "        input_mask.set_shape([None, None])\n",
    "        logits = create_dummy_if_empty(logits)\n",
    "        input_mask = create_dummy_if_empty(input_mask)\n",
    "        viterbi_decoded, potentials, sequence_length, chain_kernel = self.crf(\n",
    "            logits, input_mask)\n",
    "        if mode != PREDICT:\n",
    "            loss = -crf_log_likelihood(potentials,\n",
    "                                       labels, sequence_length, chain_kernel)[0]\n",
    "            loss = tf.reduce_mean(loss)\n",
    "            loss = nan_loss_handling(loss)\n",
    "            self.add_loss(loss)\n",
    "            acc = self.metric_fn(\n",
    "                labels, viterbi_decoded, sample_weight=input_mask)\n",
    "            self.add_metric(acc)\n",
    "\n",
    "        # make the crf prediction has the same shape as non-crf prediction\n",
    "        return tf.one_hot(viterbi_decoded, name='%s_predict' % self.problem_name, depth=self.params.num_classes[self.problem_name])\n",
    "\n",
    "    def call(self, inputs):\n",
    "        mode = get_phase()\n",
    "        training = (mode == TRAIN)\n",
    "        feature, hidden_feature = inputs\n",
    "        hidden_feature = hidden_feature['seq']\n",
    "        if mode != PREDICT:\n",
    "            labels = feature['{}_label_ids'.format(self.problem_name)]\n",
    "            # sometimes the length of labels dose not equal to length of inputs\n",
    "            # that's caused by tf.data.experimental.bucket_by_sequence_length in multi problem scenario\n",
    "            pad_len = tf.shape(input=hidden_feature)[\n",
    "                1] - tf.shape(input=labels)[1]\n",
    "\n",
    "            # top, bottom, left, right\n",
    "            pad_tensor = [[0, 0], [0, pad_len]]\n",
    "            labels = tf.pad(tensor=labels, paddings=pad_tensor)\n",
    "\n",
    "        else:\n",
    "            labels = None\n",
    "        hidden_feature = self.dropout(hidden_feature, training)\n",
    "\n",
    "        if self.params.crf:\n",
    "            return self.return_crf_result(labels, hidden_feature, mode, feature['model_input_mask'])\n",
    "\n",
    "        logits = self.dense(hidden_feature)\n",
    "\n",
    "        if self.params.detail_log:\n",
    "            for weight_variable in self.weights:\n",
    "                variable_summaries(weight_variable, self.problem_name)\n",
    "\n",
    "        if mode != PREDICT:\n",
    "            loss = empty_tensor_handling_loss(\n",
    "                labels, logits,\n",
    "                tf.keras.losses.sparse_categorical_crossentropy)\n",
    "            self.add_loss(loss)\n",
    "            acc = self.metric_fn(\n",
    "                labels, logits, sample_weight=feature['model_input_mask'])\n",
    "            self.add_metric(acc)\n",
    "        return tf.nn.softmax(\n",
    "            logits, name='%s_predict' % self.problem_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_top_layer(SequenceLabel, problem='weibo_fake_ner', params=params, sample_features=one_batch, hidden_dim=hidden_dim)\n",
    "params.crf = False\n",
    "test_top_layer(SequenceLabel, problem='weibo_fake_ner', params=params, sample_features=one_batch, hidden_dim=hidden_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get or make label encoder function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "\n",
    "\n",
    "def seq_tag_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:\n",
    "\n",
    "    le_path = get_label_encoder_save_path(params=params, problem=problem)\n",
    "    label_encoder = LabelEncoder()\n",
    "\n",
    "    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=kwargs['overwrite']):\n",
    "        # fit and save label encoder\n",
    "        label_list = [item for sublist in label_list for item in sublist] + ['[PAD]']\n",
    "        label_encoder.fit(label_list)\n",
    "        label_encoder.dump(le_path)\n",
    "        params.set_problem_info(problem=problem, info_name='num_classes', info=len(label_encoder.encode_dict))\n",
    "    else:\n",
    "        label_encoder.load(le_path)\n",
    "\n",
    "    return label_encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label handing function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def seq_tag_label_handling(tokenized_dict: dict, target: list, pad_token: str) -> tuple:\n",
    "    special_token_mask = tokenized_dict['special_tokens_mask']\n",
    "    del tokenized_dict['special_tokens_mask']\n",
    "\n",
    "    # handle truncation\n",
    "    if tokenized_dict.get('num_truncated_tokens', 0) > 0:\n",
    "        target = target[:len(target) - tokenized_dict['num_truncated_tokens']]\n",
    "\n",
    "    processed_target = []\n",
    "    for m in special_token_mask:\n",
    "        # 0 is special tokens, 1 is tokens\n",
    "        if m == 1:\n",
    "            # add pad\n",
    "            processed_target.append(pad_token)\n",
    "        else:\n",
    "            processed_target.append(target.pop(0))\n",
    "    return processed_target, tokenized_dict\n",
    "\n",
    "def seq_tag_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):\n",
    "\n",
    "    if kwargs is None:\n",
    "        return {}\n",
    "\n",
    "    if 'tokenized_inputs' not in kwargs:\n",
    "        return {}\n",
    "\n",
    "    if kwargs['tokenized_inputs'] is None:\n",
    "        return {}\n",
    "\n",
    "    tokenized_inputs = kwargs['tokenized_inputs']\n",
    "    # target should align with input_ids\n",
    "    target, tokenized_inputs = seq_tag_label_handling(\n",
    "        tokenized_inputs, target, '[PAD]')\n",
    "\n",
    "    if len(target) != len(tokenized_inputs['input_ids']):\n",
    "        raise ValueError(\n",
    "            'Length is different for seq tag problem, inputs: {}'.format(tokenizer.decode(tokenized_inputs['input_ids'])))\n",
    "\n",
    "    label_id = label_encoder.transform(target).tolist()\n",
    "    label_id = [np.int32(i) for i in label_id]\n",
    "    return label_id, None\n",
    "\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
